#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# file: mrc_ner_dataset.py

import torch
from torch.utils.data import Dataset
from transformers.tokenization_utils_base import TruncationStrategy
from transformers.utils import logging
import random
import copy
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)
# Lower-cased tokenizer type names (class name with "Tokenizer" removed) for which
# MRCNERDataset.__getitem__ counts one extra token in `sequence_added_tokens`.
MULTI_SEP_TOKENS_TOKENIZERS_SET = {"roberta", "camembert", "bart", "mpnet"}
class MRCNERDataset(Dataset):
    """
    MRC NER Dataset

    Every raw example is a dict with the keys ``entity_label``, ``query``,
    ``context`` (words separated by single spaces), ``start_position`` /
    ``end_position`` (word indices, end inclusive), ``impossible`` and
    ``qas_id``.

    Args:
        data: list of mrc-ner style example dicts (see above)
        tokenizer: tokenizer used in ``__getitem__``; it is not touched by the constructor
        max_length: int, max length of query+context
        max_query_length: int, max number of query tokens; the context budget is
            ``max_length - max_query_length``
        possible_only: if True, only use possible samples that contain answer for the query/context
        pad_to_maxlen: if True, pad every encoded sample to ``max_length``
        is_training: augmentation is only applied when this and ``DA`` are both True
        rate: factor used to compute how many augmented samples are drawn per label
        DA: if True (and ``is_training``), add augmented examples and rebalance negatives
        is_chinese: if True, spaces are removed from the seed text placed into augmented queries
        sizeonly: if True use ``convert_examples_to_matching_size``, otherwise
            ``convert_examples_to_matching_category``
        context_first: if True the context is encoded before the query
    """

    # Tokenizer classes whose ``tokenize`` is called with ``add_prefix_space=True``.
    _PREFIX_SPACE_TOKENIZERS = frozenset({
        "RobertaTokenizer",
        "LongformerTokenizer",
        "BartTokenizer",
        "RobertaTokenizerFast",
        "LongformerTokenizerFast",
        "BartTokenizerFast",
    })
    # First number used as qas_id prefix of augmented examples.
    _AUGMENTED_QAS_ID_START = 10000000

    def __init__(self, data, tokenizer, max_length: int = 512, max_query_length=64,
                 possible_only=False, pad_to_maxlen=False, is_training=False, rate=1,
                 DA=False, is_chinese=False, sizeonly=False, context_first=False):
        self.all_data = data
        self.is_chinese = is_chinese
        self.rate = rate
        self.prompt()
        if DA and is_training:
            if sizeonly:
                self.convert_examples_to_matching_size()
            else:
                self.convert_examples_to_matching_category()
            self.balance()
        self.context_first = context_first
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_query_length = max_query_length
        self.max_context_length = max_length - max_query_length
        self.possible_only = possible_only
        if self.possible_only:
            # an example is "possible" when it has at least one start position
            self.all_data = [x for x in self.all_data if x["start_position"]]
        self.pad_to_maxlen = pad_to_maxlen

    def prompt(self):
        """
        Normalise the raw examples and compute ``self.neg_ratio``.

        Each example gets its context split into words, its query rewritten to
        ``'"<label>". <query>'`` and a ``span_position`` dict mapping
        ``"start;end"`` to the answer text. ``self.neg_ratio`` is the number of
        possible examples divided by the number of impossible ones
        (``inf`` when there is no impossible example).
        """
        num_examples = len(self.all_data)
        num_impossible = sum(1 for x in self.all_data if x["impossible"])
        if num_impossible:
            self.neg_ratio = (num_examples - num_impossible) / num_impossible
        else:
            # No negatives at all: avoid dividing by zero. An infinite ratio
            # disables the negative pairing done when ``neg_ratio < 1``.
            self.neg_ratio = float("inf")

        new_datas = []
        for data in self.all_data:
            label = data['entity_label']
            context = data['context']
            start_positions = data["start_position"]
            end_positions = data["end_position"]
            words = context.split()
            # word indices are only meaningful if words are single-space separated
            assert len(words) == len(context.split(" ")), \
                "context must be tokenized with single spaces"
            span_positions = {
                "{};{}".format(start, end_positions[i]): " ".join(words[start: end_positions[i] + 1])
                for i, start in enumerate(start_positions)
            }
            new_datas.append({
                'context': words,
                'end_position': end_positions,
                'entity_label': label,
                'impossible': data['impossible'],
                'qas_id': data['qas_id'],
                'query': '"{}". {}'.format(label, data['query']),
                'span_position': span_positions,
                'start_position': start_positions,
            })
        self.all_data = new_datas

    @staticmethod
    def _index_answers_by_label(examples):
        """
        Group answers by entity label.

        Returns:
            possible_count: dict label -> number of possible examples
            answers_by_label: dict label -> list of ``(example_index, answer_text)``;
                impossible examples contribute one ``(example_index, None)`` entry
        """
        possible_count = {}
        answers_by_label = {}
        for index, example in enumerate(examples):
            label = example['entity_label']
            possible_count.setdefault(label, 0)
            if example["impossible"]:
                answers = [None]
            else:
                answers = list(example['span_position'].values())
                possible_count[label] += 1
            answers_by_label.setdefault(label, []).extend((index, answer) for answer in answers)
        return possible_count, answers_by_label

    def _append_matching_examples(self, answers_by_label, gaps, make_query):
        """
        Append augmented copies of existing examples to ``self.all_data``.

        For every label up to ``gaps[label]`` (seed answer, target example) pairs
        are sampled among possible examples, plus up to the same number with
        impossible targets when ``self.neg_ratio < 1``. Each distinct pair yields
        a shallow copy of the target example whose query is ``make_query(seed_text)``.
        """
        examples = self.all_data
        new_examples = []
        next_qas_id = self._AUGMENTED_QAS_ID_START
        for label, items in answers_by_label.items():
            gap = gaps[label]
            if gap <= 0:
                continue
            positives = [x for x in items if x[1] is not None]
            # NOTE: all pairs are materialised, which is quadratic in the number of answers.
            positive_pairs = [(s, t) for s in positives for t in positives if s[0] != t[0]]
            sampled = random.sample(positive_pairs, min(gap, len(positive_pairs)))
            if self.neg_ratio < 1:
                negative_pairs = [(s, t) for s in positives for t in items if t[1] is None]
                sampled += random.sample(negative_pairs, min(gap, len(negative_pairs)))

            # de-duplicate on (seed text, target example index), keeping first occurrences
            unique_pairs = list(dict.fromkeys((seed[1], target[0]) for seed, target in sampled))
            for seed_text, target_index in unique_pairs:
                target_example = copy.copy(examples[target_index])
                suffix = target_example['qas_id'].split(".")[1]
                target_example['qas_id'] = str(next_qas_id) + "." + suffix
                next_qas_id += 1
                target_example['query'] = make_query(seed_text)
                new_examples.append(target_example)
        self.all_data = examples + new_examples

    def convert_examples_to_matching_category(self):
        """
        Augment the data with examples whose query is the text of another answer
        of the same label. The number of sampled pairs per label is
        ``int(max_possible_count * (rate + 1)) - possible_count[label]``.
        """
        possible_count, answers_by_label = self._index_answers_by_label(self.all_data)
        max_label_num = int(max(possible_count.values(), default=0) * (self.rate + 1))
        gaps = {label: max_label_num - count for label, count in possible_count.items()}

        def make_query(seed_text):
            return seed_text.replace(" ", "") if self.is_chinese else seed_text

        self._append_matching_examples(answers_by_label, gaps, make_query)

    def convert_examples_to_matching_size(self):
        """
        Augment the data with "highlight similar parts" queries built from the
        text of another answer of the same label. The number of sampled pairs per
        label is ``int(possible_count[label] * (rate + 1))``.
        """
        possible_count, answers_by_label = self._index_answers_by_label(self.all_data)
        gaps = {label: int(count * (self.rate + 1)) for label, count in possible_count.items()}

        def make_query(seed_text):
            if self.is_chinese:
                return '突出显示与“{}”类似的部分（如果有）。'.format(seed_text.replace(" ", ""))
            return 'Highlight the parts (if any) similar to: ' + seed_text

        self._append_matching_examples(answers_by_label, gaps, make_query)

    def balance(self):
        """
        Randomly drop impossible examples. Each one is kept with probability
        ``num_possible / num_impossible``; possible examples are always kept.
        """
        examples = self.all_data
        num_impossible = sum(1 for x in examples if x["impossible"])
        if not num_impossible:
            # nothing to subsample (and the keep fraction would divide by zero)
            return
        neg_keep_frac = (len(examples) - num_impossible) / num_impossible
        # random.random() is only drawn for impossible examples
        self.all_data = [
            x for x in examples
            if not x["impossible"] or random.random() < neg_keep_frac
        ]

    def __len__(self):
        return len(self.all_data)

    def _tokenize_word(self, word):
        """Split one context word into sub-tokens with ``self.tokenizer``."""
        tokenizer = self.tokenizer
        class_name = tokenizer.__class__.__name__
        if class_name in self._PREFIX_SPACE_TOKENIZERS:
            return tokenizer.tokenize(word, add_prefix_space=True)
        if class_name == 'BertWordPieceTokenizer':
            return tokenizer.encode(word, add_special_tokens=False).tokens
        return tokenizer.tokenize(word)

    def __getitem__(self, item):
        """
        Args:
            item: int, idx
        Returns:
            tokens: tokens of query + context, [seq_len]
            attention_mask: attention mask, 1 for token, 0 for padding, [seq_len]
            token_type_ids: token type ids as returned by the tokenizer, [seq_len]
            label_mask: label mask, 1 for counting into loss, 0 for ignoring. [seq_len]
            match_labels: match labels, [seq_len, seq_len]
        """
        data = self.all_data[item]
        tokenizer = self.tokenizer

        query = data["query"]
        context = data["context"]
        start_positions = data["start_position"]
        end_positions = data["end_position"]

        # number of special tokens surrounding a (query, context) pair
        tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
        sequence_added_tokens = tokenizer.model_max_length - tokenizer.max_len_single_sentence
        if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET:
            sequence_added_tokens += 1

        # sub-tokenize the context, remembering where each word starts
        orig_to_tok_index = []
        all_doc_tokens = []
        for word in context:
            orig_to_tok_index.append(len(all_doc_tokens))
            all_doc_tokens.extend(self._tokenize_word(word))

        # word-level spans -> sub-token-level spans (end = last sub-token of the end word)
        tok_start_positions = [orig_to_tok_index[x] for x in start_positions]
        last_word = len(context) - 1
        tok_end_positions = [
            orig_to_tok_index[x + 1] - 1 if x < last_word else len(all_doc_tokens) - 1
            for x in end_positions
        ]

        truncation = TruncationStrategy.ONLY_SECOND.value
        padding_strategy = "max_length" if self.pad_to_maxlen else "do_not_pad"

        if self.context_first:
            truncated_context = tokenizer.encode(
                all_doc_tokens, add_special_tokens=False, truncation=True, max_length=self.max_context_length
            )
            encoded_dict = tokenizer.encode_plus(  # TODO(thom) update this logic
                truncated_context,
                query,
                truncation=truncation,
                padding=padding_strategy,
                max_length=self.max_length,
                return_overflowing_tokens=True,
                return_token_type_ids=True,
            )
            tokens = encoded_dict['input_ids']
            type_ids = encoded_dict['token_type_ids']
            attn_mask = encoded_dict['attention_mask']
            # find new start_positions/end_positions, considering
            # 1. we add cls token at the beginning
            doc_offset = 1
            last_allowed = self.max_context_length
            # Filter starts and ends as pairs so a dropped span never shifts the
            # pairing of the remaining ones; ends are clamped to the last slot.
            spans = [
                (start + doc_offset, min(end + doc_offset, last_allowed))
                for start, end in zip(tok_start_positions, tok_end_positions)
                if start + doc_offset <= last_allowed
            ]
            label_mask = [0] * doc_offset + [1] * len(truncated_context) + [0] * (len(tokens) - len(truncated_context) - 1)
        else:
            truncated_query = tokenizer.encode(
                query, add_special_tokens=False, truncation=True, max_length=self.max_query_length
            )
            encoded_dict = tokenizer.encode_plus(  # TODO(thom) update this logic
                truncated_query,
                all_doc_tokens,
                truncation=truncation,
                padding=padding_strategy,
                max_length=self.max_length,
                return_overflowing_tokens=True,
                return_token_type_ids=True,
            )
            tokens = encoded_dict['input_ids']
            type_ids = encoded_dict['token_type_ids']
            attn_mask = encoded_dict['attention_mask']

            # find new start_positions/end_positions, considering
            # 1. we add query tokens at the beginning
            # 2. special tokens
            doc_offset = len(truncated_query) + sequence_added_tokens
            last_allowed = self.max_length - 2
            spans = [
                (start + doc_offset, min(end + doc_offset, last_allowed))
                for start, end in zip(tok_start_positions, tok_end_positions)
                if start + doc_offset <= last_allowed
            ]

            # Only real context tokens count into the loss: the query part, the
            # closing special token and any padding (assumed to be appended on
            # the right) are masked out.
            num_real_tokens = sum(attn_mask)
            label_mask = (
                [0] * doc_offset
                + [1] * (num_real_tokens - doc_offset - 1)
                + [0] * (len(tokens) - num_real_tokens + 1)
            )

        new_start_positions = [start for start, _ in spans]
        new_end_positions = [end for _, end in spans]

        assert all(label_mask[p] != 0 for p in new_start_positions)
        assert all(label_mask[p] != 0 for p in new_end_positions)

        assert len(label_mask) == len(tokens)

        seq_len = len(tokens)
        match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long)
        for start, end in spans:
            if start >= seq_len or end >= seq_len:
                continue
            match_labels[start, end] = 1

        return [
            torch.LongTensor(tokens),
            torch.LongTensor(attn_mask),
            torch.LongTensor(type_ids),
            torch.LongTensor(label_mask),
            match_labels,
        ]

def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)